import torch
from torch.utils.data import DataLoader
import os
import argparse
import torch_optimizer as optim
from tensorboardX import SummaryWriter
from utils.trainer import Trainer
import torch.nn as nn
from datasets.ad_ds import AD_Dataset, load_data
import os
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.model_selection import StratifiedKFold
import time
import copy
import warnings
import numpy as np

from models.loss.focal_loss import FocalLoss
from models.loss.contrastive_loss2 import Contrastive_Loss
warnings.filterwarnings("ignore")
if __name__ == "__main__":
    # python -m visdom.server
    # mp.set_start_method('spawn')

    torch.backends.cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument("--epoch", type=int, default=150, help="epoch")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size")
    parser.add_argument("--learning_rate", type=float, default=0.001, help="learning_rate")
    parser.add_argument("--log_path", type=str, default='log/tensorboard/',
                        help="log_path")
    parser.add_argument("--data_path", type=str, default='./train/train',
                        help="data_path")
    parser.add_argument("--label_path", type=str, default=r'./train/train_open.csv',
                        help="label_path")
    parser.add_argument("--data_url", type=str, default='',
                        help="data_url")
    parser.add_argument("--train_url", type=str, default='',
                        help="train_url")
    parser.add_argument("--log_url", type=str, default='',
                        help="log_url")
    parser.add_argument("--init_method", type=str, default='',
                        help="init_method")
    parser.add_argument("--save_name", type=str, default='dnn0_5layers_focal_1',
                        help="save_name")
    parser.add_argument("--num_gpus", type=int, default=1,
                        help="num_gpus")
    # parser.add_argument("--save_name", type=str, default='dnn_residual_focal_2',
    #                     help="save_name")
    # parser.add_argument("--save_name", type=str, default='dnn1_5_layers_focal_06242216',
    #                     help="save_name")
    # parser.add_argument("--model", type=str, default='cbam18',
    #                     help="cbam18,resnet18,effnetb4,ca18,cbam34,cbam50")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    opt = parser.parse_args()
    print(str(opt))
    # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
    BATCH_SIZE = opt.batch_size
    EPOCH = opt.epoch
    learning_rate = opt.learning_rate
    pretrain_w_path = opt.pretrain_weight_path if 'pretrain_weight_path' in opt else ''
    n_samples = BATCH_SIZE * 20

    writer = SummaryWriter(os.path.join(opt.log_path, opt.save_name), comment=opt.save_name,
                           flush_secs=2)
    save_path = os.path.join('save', opt.save_name)

    if not os.path.exists(opt.data_path):
        import moxing as mox
        mox.file.copy_parallel(opt.data_url, './train/')
        print('数据已加载')

    x, y = load_data(opt.data_path,opt.label_path)
    x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)
    mean = np.mean(x, axis=0)
    std = np.std(x, axis=0)
    x = (x - mean) / std
    x = np.nan_to_num(x, nan=0.0, posinf=0, neginf=0)

    folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=2021).split(x, y)

    from models.dnn0 import DNN
    # from models.dnn_1 import DNN
    # from models.dnn_residual import DNN
    init_model = DNN(28169, 4096, 512, 3, dropout_p=0.4)
    # init_model = DNN(28169)



    # sampler = WeightedRandomSampler(weights=train_data.sample_weight, num_samples=n_samples,
    #                                 replacement=True)

    loss_list = []

    max_auc_list = []
    chkp_list = []


    for fold, (trn_idx, val_idx) in enumerate(folds):
        print('------------------Fold %i--------------------' % fold)
        train_data = AD_Dataset(x, y,trn_idx,device)
        val_data = AD_Dataset(x, y,val_idx,device)

        train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True)
        val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8,
                                      pin_memory=True)  # 使用DataLoader加载数据

        model = copy.deepcopy(init_model)
        model = nn.DataParallel(model)
        model = model.to(device)
        optimizer = optim.RAdam(
            model.parameters(),
            lr=1e-3,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
        )
        # criterion = nn.CrossEntropyLoss()
        criterion = FocalLoss(gamma=2)
        # criterion = Contrastive_Loss(smoothing_value=0.1)
        trainer = Trainer(model, optimizer, criterion, train_data_loader, val_data_loader,device, epoch=EPOCH)
        save_name = os.path.join(save_path,'f%i' % (fold))
        if not os.path.exists(save_name):
            os.makedirs(save_name)
        min_val_loss, max_val_auc = trainer.train(save_name,fold)
        # print('Fold' + str(fold), min_val_loss)
        # print('Fold' + str(fold), max_val_auc)
        # max_auc_list.append(max_val_auc)
        # chkp_list.append(save_name)

    if opt.train_url !='' :
        if '/home/' not in opt.train_url:
            import moxing as mox
            # from deep_moxing.model_analysis.api import analyse, tmp_save

            # model_path = 'obs://ad-competiton/my_baseline/model'
            train_url = 'obs:'+opt.train_url.replace('s3:','')
            data_url = 'obs:'+opt.data_url.replace('s3:','')
            log_url = 'obs://'+opt.log_url
            print('Start to save model to',train_url,'from',save_path)
            np.save('./mean.npy', mean)
            np.save('./std.npy', std)
            mox.file.copy('./mean.npy', train_url + '/mean.npy')
            mox.file.copy('./std.npy', train_url + '/std.npy')
            mox.file.copy('./std.npy', train_url + '/std.npy')
            mox.file.copy_parallel(save_path,train_url)
            mox.file.copy_parallel('./log',log_url)
        else:
            from shutil import copyfile

            np.save(os.path.join(opt.train_url,'/mean.npy'), mean)
            np.save(os.path.join(opt.train_url,'/std.npy'), std)
            copyfile(save_path,opt.train_url)